// Copyright (c) 2023 Huawei Technologies Co., Ltd
// Copyright (c) 2019, Facebook CORPORATION.
// All rights reserved.
//
// Licensed under the BSD 3-Clause License  (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "op_plugin/OpApiInterface.h"
#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/utils/op_api_common.h"

namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;

namespace {
// Number of int4 values packed into a single int32 element of the QUInt4x2 output.
constexpr int64_t INT4_NUMS_IN_INT32_SPACE = 8;
// Dtype id callers pass to request HiFloat8 output; it is handled alongside the
// at::ScalarType-derived ids in the tables below.
constexpr int64_t TYPE_NUM_FOR_HIFLOAT8 = 290;

// Dtype ids accepted by the aclnnQuantize (div_mode) path, mapped to the
// scalar type used to allocate the output tensor. HiFloat8 is stored as Byte.
std::map<int64_t, at::ScalarType> QUANTIZE_SUPPORT_MAP = {
    {static_cast<int64_t>(at::kQUInt8), at::ScalarType::Byte},
    {static_cast<int64_t>(at::kQInt8), at::ScalarType::Char},
    {static_cast<int64_t>(at::kQInt32), at::ScalarType::Int},
    {static_cast<int64_t>(at::kByte), at::ScalarType::Byte},
    {static_cast<int64_t>(at::kChar), at::ScalarType::Char},
    {static_cast<int64_t>(at::kInt), at::ScalarType::Int},
    {static_cast<int64_t>(at::kFloat8_e4m3fn), at::ScalarType::Float8_e4m3fn},
    {static_cast<int64_t>(at::kFloat8_e5m2), at::ScalarType::Float8_e5m2},
    {TYPE_NUM_FOR_HIFLOAT8, at::ScalarType::Byte},
};

// Dtype ids accepted by the AscendQuant (non-div_mode) path.
std::initializer_list<int64_t> ASCEND_QUANT_V2_SUPPORT_LIST = {
    static_cast<int64_t>(at::kChar),
    static_cast<int64_t>(at::kQInt8),
    static_cast<int64_t>(at::ScalarType::QUInt4x2),
    static_cast<int64_t>(at::kFloat8_e4m3fn),
    static_cast<int64_t>(at::kFloat8_e5m2),
    TYPE_NUM_FOR_HIFLOAT8,
};
}  // namespace

/**
 * Quantize `self` with the aclnnQuantize kernel (the div_mode path).
 *
 * Falls back to acl_op::npu_quantize when aclnnQuantize (or its workspace-size
 * query) cannot be resolved from the loaded op-api library.
 *
 * @param self            Tensor to quantize.
 * @param scales          Quantization scales, forwarded to the kernel.
 * @param zero_points_opt Optional zero points, forwarded to the kernel.
 * @param dtype           Requested output dtype id; must be a key of QUANTIZE_SUPPORT_MAP.
 * @param axis            Quantization axis, forwarded unchanged to the kernel.
 * @return A newly allocated tensor with the shape of `self` holding the result.
 */
at::Tensor npu_quantize_by_kernel(
    const at::Tensor& self,
    const at::Tensor& scales,
    const c10::optional<at::Tensor>& zero_points_opt,
    int64_t dtype,
    int64_t axis)
{
    // check if aclnn api implemented (resolved once, then cached in the statics)
    static auto quantizeFuncAddr = GetOpApiFuncAddr("aclnnQuantize");
    static auto quantizeGetWorkspaceSizeAddr = GetOpApiFuncAddr("aclnnQuantizeGetWorkspaceSize");
    if (quantizeFuncAddr == nullptr || quantizeGetWorkspaceSizeAddr == nullptr) {
        return acl_op::npu_quantize(self, scales, zero_points_opt, dtype, axis);
    }
    // check output datatype supported; keep the iterator so the scalar type is
    // read without a second (default-inserting) operator[] lookup.
    const auto dtype_it = QUANTIZE_SUPPORT_MAP.find(dtype);
    TORCH_CHECK(dtype_it != QUANTIZE_SUPPORT_MAP.end(),
        "Param (dtype) must be Int8, UInt8, Int32, HiFloat8, Float8_e4m3fn, Float8_e5m2" + OPS_ERROR(ErrCode::TYPE));
    const at::ScalarType scalarDtype = dtype_it->second;

    auto output_shape = op_infer::array_to_small_vector(self.sizes());

    aclDataType yAclType = npu_preparation::convert_to_acl_data_type(scalarDtype);
    if (dtype == TYPE_NUM_FOR_HIFLOAT8) {
        // HiFloat8 output is allocated with Byte storage but must be tagged as
        // ACL_HIFLOAT8 for the kernel.
        yAclType = ACL_HIFLOAT8;
    }
    at::Tensor y = npu_preparation::apply_tensor_without_format(output_shape, self.options().dtype(scalarDtype));
    TensorWrapper y_wrapper = {y, yAclType};
    EXEC_NPU_CMD(aclnnQuantize, self, scales, zero_points_opt, yAclType, axis, y_wrapper);
    return y;
}

/**
 * Quantize `self` with aclnnAscendQuantV3, falling back to aclnnAscendQuant when
 * V3 cannot be resolved from the loaded op-api library (the non-div_mode path).
 *
 * @param self            Tensor to quantize.
 * @param scales          Quantization scales, forwarded to the kernel.
 * @param zero_points_opt Optional zero points, forwarded to the kernel.
 * @param dtype           Requested output dtype id; must be in ASCEND_QUANT_V2_SUPPORT_LIST.
 *                        For QUInt4x2, INT4_NUMS_IN_INT32_SPACE int4 values are packed
 *                        into each int32, so `self` must have at least one dimension
 *                        and its last dimension must be a multiple of that count.
 * @param axis            Used only by the V3 kernel; values >= -1 are replaced by -1
 *                        (i.e. the value passed is min(axis, -1)).
 * @return A newly allocated tensor holding the quantized values.
 */
at::Tensor npu_quantize_by_ascend_quant(
    const at::Tensor& self,
    const at::Tensor& scales,
    const c10::optional<at::Tensor>& zero_points_opt,
    int64_t dtype,
    int64_t axis)
{
    TORCH_CHECK(std::find(ASCEND_QUANT_V2_SUPPORT_LIST.begin(), ASCEND_QUANT_V2_SUPPORT_LIST.end(), dtype) !=
        ASCEND_QUANT_V2_SUPPORT_LIST.end(),
        "Param (dtype) must be Int8, QInt8, QUInt4x2, HiFloat8, Float8_e4m3fn, Float8_e5m2" + OPS_ERROR(ErrCode::TYPE));

    at::ScalarType scalarDtype = at::ScalarType::Undefined;
    aclDataType yAclType = ACL_INT8;
    if (dtype == static_cast<int64_t>(at::kQInt8)) {
        yAclType = ACL_INT8;
        scalarDtype = at::ScalarType::Char;
    } else if (dtype == static_cast<int64_t>(at::ScalarType::QUInt4x2)) {
        // int4 pack to int32
        yAclType = ACL_INT32;
        scalarDtype = at::ScalarType::Int;
    } else {
        yAclType = torch_npu::te::GetAclDataType(dtype);
        scalarDtype = npu_preparation::convert_to_scalar_type(yAclType);
    }

    at::Tensor result;
    if (scalarDtype == at::ScalarType::Int) {
        const int64_t x_dim_num = self.dim();
        // The packing below indexes the last dimension; a 0-d input has none, and
        // indexing output_shape[-1] would be out of bounds.
        TORCH_CHECK(x_dim_num > 0,
                    "Input must have at least one dimension for int4 (QUInt4x2) packed output" +
                    OPS_ERROR(ErrCode::PARAM));
        auto output_shape = op_infer::array_to_small_vector(self.sizes());
        TORCH_CHECK(output_shape[x_dim_num - 1] % INT4_NUMS_IN_INT32_SPACE == 0,
                    "Input shape last dim must be divisible by 8" + OPS_ERROR(ErrCode::PARAM));
        output_shape[x_dim_num - 1] /= INT4_NUMS_IN_INT32_SPACE;
        result = npu_preparation::apply_tensor_without_format(output_shape, self.options().dtype(scalarDtype));
    } else {
        result = npu_preparation::apply_tensor(self, self.options().dtype(scalarDtype));
    }
    TensorWrapper y_wrapper = {result, yAclType};
    const bool sqrt_mode = false;
    static auto opApiFuncAddr = GetOpApiFuncAddr("aclnnAscendQuantV3");
    static auto opApiGetWorkspaceSizeFuncAddr = GetOpApiFuncAddr("aclnnAscendQuantV3GetWorkspaceSize");
    if (opApiFuncAddr == nullptr || opApiGetWorkspaceSizeFuncAddr == nullptr) {
        EXEC_NPU_CMD(aclnnAscendQuant, self, scales, zero_points_opt, sqrt_mode, "round", yAclType, y_wrapper);
    } else {
        // V3 takes an axis; anything >= -1 is forced to -1.
        axis = axis < -1 ? axis : -1;
        EXEC_NPU_CMD(aclnnAscendQuantV3, self, scales, zero_points_opt, sqrt_mode, "round", yAclType, axis, y_wrapper);
    }
    return result;
}

/**
 * Entry point for NPU quantization.
 *
 * @param div_mode When true, dispatch to npu_quantize_by_kernel (aclnnQuantize);
 *                 otherwise dispatch to npu_quantize_by_ascend_quant (AscendQuant).
 * All other arguments are passed through unchanged to the selected implementation,
 * whose result is returned.
 */
at::Tensor npu_quantize(
    const at::Tensor& self,
    const at::Tensor& scales,
    const c10::optional<at::Tensor>& zero_points_opt,
    int64_t dtype,
    int64_t axis,
    bool div_mode)
{
    return div_mode
        ? npu_quantize_by_kernel(self, scales, zero_points_opt, dtype, axis)
        : npu_quantize_by_ascend_quant(self, scales, zero_points_opt, dtype, axis);
}

} // namespace op_api
